import abc
from typing import Tuple, List

import torch
import torch.nn as nn

from .depthwise_sep_conv import DepthWiseSeperableConv
from .multidilated_conv import MultidilatedConv


class BaseDiscriminator(nn.Module):
    """
    Interface for discriminators that return scores together with intermediate activations.

    Note: this class does not use ``abc.ABCMeta``, so ``abc.abstractmethod`` is not enforced
    at instantiation time; a ``forward`` that was not overridden fails only when it is called.
    """

    @abc.abstractmethod
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Predict scores and get intermediate activations. Useful for feature matching loss
        :param x: input tensor
        :return: tuple (scores, list of intermediate activations)
        :raises NotImplementedError: always; subclasses must override this method
        """
        # ``NotImplemented`` is a non-callable sentinel for binary operators; calling it
        # raises TypeError. The exception meant for unimplemented methods is NotImplementedError.
        raise NotImplementedError()


def get_conv_block_ctor(kind="default"):
    """
    Resolve a convolution block constructor from its short name.

    :param kind: one of "default", "depthwise", "multidilated"; any non-string value
        is treated as an already resolved constructor and returned unchanged
    :return: the constructor (a class, not an instance) for the requested block
    :raises ValueError: if ``kind`` is a string that is not one of the known names
    """
    if isinstance(kind, str):
        # Branches are resolved one at a time so each class name is looked up only when requested.
        if kind == "default":
            ctor = nn.Conv2d
        elif kind == "depthwise":
            ctor = DepthWiseSeperableConv
        elif kind == "multidilated":
            ctor = MultidilatedConv
        else:
            raise ValueError(f"Unknown convolutional block kind {kind}")
        return ctor
    # Non-string input: the caller supplied the constructor directly.
    return kind


def get_norm_layer(kind="bn"):
    """
    Resolve a normalization layer class from its short name.

    :param kind: "bn" for ``nn.BatchNorm2d`` or "in" for ``nn.InstanceNorm2d``; any
        non-string value is returned unchanged
    :return: the normalization layer class (not an instance)
    :raises ValueError: if ``kind`` is a string that is not one of the known names
    """
    if not isinstance(kind, str):
        # The caller already passed a layer constructor.
        return kind

    known_layers = {
        "bn": nn.BatchNorm2d,
        "in": nn.InstanceNorm2d,
    }
    layer_cls = known_layers.get(kind)
    if layer_cls is None:
        raise ValueError(f"Unknown norm block kind {kind}")
    return layer_cls


def get_activation(kind="tanh"):
    """
    Build an activation module from its short name.

    :param kind: "tanh", "sigmoid", or the literal ``False`` to disable the activation
    :return: a new ``nn.Tanh``, ``nn.Sigmoid`` or ``nn.Identity`` instance
    :raises ValueError: for any other value (including ``None`` and ``0``)
    """
    # Identity check on purpose: only the bool ``False`` disables the activation,
    # other falsy values such as 0 or None are rejected below.
    if kind is False:
        return nn.Identity()

    if kind == "tanh":
        module = nn.Tanh()
    elif kind == "sigmoid":
        module = nn.Sigmoid()
    else:
        raise ValueError(f"Unknown activation kind {kind}")
    return module


class SimpleMultiStepGenerator(nn.Module):
    """
    Chain of modules where each step receives the original input concatenated
    (along dim 1) with the outputs of all previous steps.
    """

    def __init__(self, steps: List[nn.Module]):
        """
        :param steps: modules to run in order; registered as an ``nn.ModuleList``
        """
        super().__init__()
        self.steps = nn.ModuleList(steps)

    def forward(self, x):
        """
        Run every step and return all step outputs concatenated along dim 1.

        :param x: input tensor fed to the first step
        :return: step outputs concatenated along dim 1, last step's output first
        """
        step_input = x
        collected = []
        for module in self.steps:
            produced = module(step_input)
            collected.append(produced)
            # The next step sees everything accumulated so far plus this step's output.
            step_input = torch.cat([step_input, produced], dim=1)
        # Latest output goes first in the final tensor.
        collected.reverse()
        return torch.cat(collected, dim=1)


def deconv_factory(kind, ngf, mult, norm_layer, activation, max_features):
    """
    Build the list of layers for one 2x upsampling stage.

    :param kind: "convtranspose" (strided ``nn.ConvTranspose2d``) or "bilinear"
        (``nn.Upsample`` followed by a ``DepthWiseSeperableConv``)
    :param ngf: base number of features; the stage input has ``ngf * mult`` channels
    :param mult: channel multiplier of the stage input; the output uses ``int(ngf * mult / 2)``
    :param norm_layer: callable that builds a normalization layer from a channel count
    :param activation: activation module instance appended as the last layer (not copied)
    :param max_features: upper bound applied to both input and output channel counts
    :return: list ``[upsampling layers..., norm, activation]``
    :raises ValueError: if ``kind`` is not one of the supported names
    """
    # Validate first so an unknown kind is reported before any channel arithmetic runs.
    # ValueError matches the other factories in this module and is still an ``Exception``
    # for callers that catch the broader type.
    if kind != "convtranspose" and kind != "bilinear":
        raise ValueError(f"Invalid deconv kind: {kind}")

    in_channels = min(max_features, ngf * mult)
    out_channels = min(max_features, int(ngf * mult / 2))

    if kind == "convtranspose":
        # stride=2 with padding=1 and output_padding=1 doubles the spatial size for kernel 3.
        upsampling = [
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            ),
        ]
    else:
        upsampling = [
            nn.Upsample(scale_factor=2, mode="bilinear"),
            DepthWiseSeperableConv(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
        ]

    return upsampling + [norm_layer(out_channels), activation]
